""" Buffer Implemntation """
import os
from functools import partial
from typing import Any, Tuple, NamedTuple, List, Dict, Type, Optional, Callable

import pickle as pkl
import numpy as np
import jax
import jax.numpy as jnp
from gym import spaces

from sb3_jax.common.buffers import BaseBuffer
from diffgro.utils.utils import print_r, print_b


def load_traj(path: str):
    """Load a pickled trajectory object from ``path``.

    Args:
        path: location of the pickle file.

    Returns:
        Whatever object was pickled into the file.

    Raises:
        SystemExit: with exit status 1 when ``path`` does not exist; a message
            is printed through ``print_r`` first.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading, so
    this must only be pointed at trusted files.
    """
    if not os.path.exists(path):
        print_r(f"File not exist at {path}")
        # The builtin `exit()` is injected by the `site` module for interactive
        # sessions and reports success (status 0); raise SystemExit directly so
        # the failure is visible to the calling shell.
        raise SystemExit(1)
    with open(path, 'rb') as f:
        return pkl.load(f)


def _close_segment(segment: Dict[str, list], path: Dict[str, Any], next_index: int) -> Dict[str, Any]:
    """Convert an accumulated skill segment to arrays, adding its closing observation."""
    # The closing observation is the one at the step right after the segment.
    # It may be missing for the last segment of a path whose observations are
    # exactly as long as its actions; the segment is then emitted without it.
    if next_index < len(path["observations"]):
        segment["observations"].append(path["observations"][next_index])
    closed = {key: np.array(value) for key, value in segment.items()}
    # Carry the path-level task entry along when the source path provides one.
    if "task" in path:
        closed["task"] = path["task"]
    return closed


def convert_traj(trajectories: List[Dict[str, np.ndarray]]) -> List[Dict[str, np.ndarray]]:
    """Split every trajectory into skill-level sub-trajectories.

    A new segment starts whenever ``path['skill_langs'][t]`` differs from the
    previous step. Each segment holds the per-step entries of that span plus,
    when available, one extra closing observation (the observation at the step
    following the span), so ``observations`` is one longer than ``actions``.

    Args:
        trajectories: paths with per-step ``observations``, ``actions``,
            ``skill_langs``, ``skill_embds``, ``rewards``, ``terminals`` and
            ``infos`` entries, and optionally a path-level ``task`` entry.

    Returns:
        A flat list of segments (in path order, then time order). Every
        per-step entry is converted with ``np.array``; ``task`` is copied over
        unchanged when the path has one.
    """
    step_keys = ("observations", "actions", "skill_langs", "skill_embds",
                 "rewards", "terminals", "infos")

    skill_trajectories = []
    for path in trajectories:
        num_steps = len(path["actions"])
        skill, segment = None, None
        for t in range(num_steps):
            current = path["skill_langs"][t]
            if segment is None or skill != current:
                if segment is not None:
                    # skill changed: observation t closes the previous segment
                    skill_trajectories.append(_close_segment(segment, path, t))
                segment = {key: [] for key in step_keys}
                skill = current

            for key in step_keys:
                segment[key].append(path[key][t])

        # Flush the segment that runs until the end of the path; without this
        # the last skill of every path (and any single-skill path) is lost.
        if segment is not None:
            skill_trajectories.append(_close_segment(segment, path, num_steps))
    return skill_trajectories


class TrajectoryBufferSamples(NamedTuple):
    """Batch of fixed-length trajectory windows.

    Shapes below are the ones produced by ``TrajectoryBuffer`` in this file.
    A field is ``None`` when its name was not listed in the ``batch_keys``
    passed to ``sample``.
    """
    # (batch, skill_dim) float32, filled from the trajectory's 'task' entry
    tasks: np.ndarray
    # (batch, max_length) + obs_shape float32, observations of the window
    observations: np.ndarray
    # (batch, max_length, act_dim) float32
    actions: np.ndarray
    # (batch, skill_dim) float32, skill embedding at the window's first step
    skills: np.ndarray
    # same shape as observations, shifted one step forward in time
    next_observations: np.ndarray
    # (batch, max_length, 1) float32
    rewards: np.ndarray
    # (batch, max_length) float32, taken from 'terminals' (or 'dones')
    dones: np.ndarray
    # (batch, max_length) int32, step indices inside the source trajectory
    timesteps: np.ndarray
    # (batch, max_length) int32, all ones
    masks: np.ndarray


class TrajectoryBuffer(BaseBuffer):
    """Buffer of fixed-length windows cut from a list of trajectories.

    All windows of a given length are materialised once and cached in
    dictionaries keyed by that length; sampling only indexes the cached arrays.

    ``obs_shape``, ``obs_dim`` and ``act_dim`` are read from ``self`` and are
    presumably provided by ``BaseBuffer`` (not visible in this file).
    """

    def __init__(
        self,
        trajectories: List[Dict[str, np.ndarray]],
        max_length: int,
        convert: bool = True,
        observation_space: spaces.Space = None,
        action_space: spaces.Space = None,
    ):
        """Store the trajectories and pre-compute the windows of ``max_length``.

        Args:
            trajectories: list of trajectory dictionaries.
            max_length: default window length.
            convert: when true, trajectories are first split into skill-level
                segments with ``convert_traj``.
            observation_space: forwarded to ``BaseBuffer``.
            action_space: forwarded to ``BaseBuffer``.
        """
        super(TrajectoryBuffer, self).__init__(None, observation_space, action_space)
        self.trajectories = convert_traj(trajectories) if convert else trajectories
        self.max_length = max_length
        self.setup()

    def setup(self) -> None:
        """Compute dataset statistics and build the default-length windows."""
        # per-window-length caches, filled by `_set_inds`
        self.inds = {}
        self.tasks = {}
        self.observations = {}
        self.actions = {}
        self.skills = {}  # skill embeddings
        self.rewards = {}
        self.next_observations = {}
        self.dones = {}
        self.timesteps = {}
        self.masks = {}

        # get skill dimension
        self.skill_dim = self.trajectories[0]['skill_embds'][0].shape[-1]
        # trajectory lengths (counted in observations)
        observations, actions, traj_lengths = [], [], []
        for path in self.trajectories:
            observations.append(path['observations'])
            actions.append(path['actions'])
            traj_lengths.append(len(path['observations']))
        self.traj_lengths = np.array(traj_lengths)
        observations, actions = np.concatenate(observations, axis=0), np.concatenate(actions, axis=0)
        # Statistics are exposed as attributes; the stored windows themselves
        # are kept unnormalised.
        self.obs_mean, self.obs_std = np.mean(observations, axis=0), np.std(observations, axis=0) + 1e-6
        self.act_mean, self.act_std = np.mean(actions, axis=0), np.std(actions, axis=0) + 1e-6

        self._set_inds(self.max_length)

    def _set_inds(self, max_length: int) -> None:
        """Materialise and cache every window of ``max_length`` steps.

        A window ``(trajectory, start, end)`` is created for each start such
        that ``end = start + max_length`` is still a valid observation index,
        which keeps the one-step-shifted ``next_observations`` slice in range.
        """
        inds = [
            (i, start, start + max_length)
            for i, traj_length in enumerate(self.traj_lengths)
            for start in range(traj_length - max_length)
        ]
        # reshape keeps an (0, 3) layout when no trajectory is long enough
        inds = np.array(inds, dtype=np.int64).reshape(-1, 3)
        self.inds[max_length] = inds
        num_windows = len(inds)

        tasks = np.zeros((num_windows, self.skill_dim), dtype=np.float32)
        observations = np.zeros((num_windows, max_length) + self.obs_shape, dtype=np.float32)
        actions = np.zeros((num_windows, max_length, self.act_dim), dtype=np.float32)
        skills = np.zeros((num_windows, self.skill_dim), dtype=np.float32)
        next_observations = np.zeros((num_windows, max_length) + self.obs_shape, dtype=np.float32)
        rewards = np.zeros((num_windows, max_length, 1), dtype=np.float32)
        dones = np.zeros((num_windows, max_length), dtype=np.float32)
        timesteps = np.zeros((num_windows, max_length), dtype=np.int32)
        masks = np.ones((num_windows, max_length), dtype=np.int32)

        for i, (ind, si, en) in enumerate(inds):
            traj = self.trajectories[ind]
            if tasks is not None:
                if 'task' in traj:
                    tasks[i] = traj['task'].reshape(1, -1, self.skill_dim)
                else:
                    # No task information for this data: keep the buffer usable
                    # for the other keys and fail only if 'tasks' is requested.
                    tasks = None
            observations[i] = traj['observations'][si:en].reshape(1, -1, self.obs_dim)
            next_observations[i] = traj['observations'][si+1:en+1].reshape(1, -1, self.obs_dim)
            actions[i] = traj['actions'][si:en].reshape(1, -1, self.act_dim)
            # the skill of a window is the one active at its first step
            skills[i] = traj['skill_embds'][si].reshape(1, self.skill_dim)
            rewards[i] = traj['rewards'][si:en].reshape(1, -1, 1)
            done_key = 'terminals' if 'terminals' in traj else 'dones'
            dones[i] = traj[done_key][si:en].reshape(1, -1)
            timesteps[i] = np.arange(si, en).reshape(1, -1)

        self.tasks[max_length] = tasks
        self.observations[max_length] = observations
        self.actions[max_length] = actions
        self.skills[max_length] = skills
        self.next_observations[max_length] = next_observations
        self.rewards[max_length] = rewards
        self.dones[max_length] = dones
        self.timesteps[max_length] = timesteps
        self.masks[max_length] = masks

    def sample(self, batch_keys: List[str], batch_size: int, max_length: Optional[int] = None) -> TrajectoryBufferSamples:
        """Draw ``batch_size`` windows uniformly with replacement.

        Args:
            batch_keys: names of the sample fields to fill; others are ``None``.
            batch_size: number of windows to draw.
            max_length: window length; defaults to the buffer's ``max_length``.
                Windows for a new length are built on first use.

        Raises:
            ValueError: if no trajectory is long enough to yield a window.
        """
        if max_length is None:
            max_length = self.max_length
        if max_length not in self.inds:
            self._set_inds(max_length)
        num_windows = len(self.inds[max_length])
        if num_windows == 0 and batch_size > 0:
            raise ValueError(
                f"no trajectory is long enough to provide a window of length {max_length}"
            )

        batch_inds = np.random.choice(
            np.arange(num_windows),
            size=batch_size,
            replace=True,
        )
        return self._get_samples(batch_keys, batch_inds, max_length)

    def _get_samples(self, batch_keys: List[str], batch_inds: np.ndarray, max_length: Optional[int] = None) -> TrajectoryBufferSamples:
        """Gather the cached windows at ``batch_inds`` for the requested keys.

        Raises:
            KeyError: if 'tasks' is requested but the trajectories carry no
                'task' entry, or if windows of ``max_length`` were never built.
        """
        # the caches are keyed by window length, so None must be resolved first
        if max_length is None:
            max_length = self.max_length
        if 'tasks' in batch_keys and self.tasks[max_length] is None:
            raise KeyError("'tasks' requested but the trajectories have no 'task' entry")
        data = (
            self.tasks[max_length][batch_inds,:] if 'tasks' in batch_keys else None,
            self.observations[max_length][batch_inds,:,:] if 'observations' in batch_keys else None,
            self.actions[max_length][batch_inds,:,:] if 'actions' in batch_keys else None,
            self.skills[max_length][batch_inds,:] if 'skills' in batch_keys else None,
            self.next_observations[max_length][batch_inds,:,:] if 'next_observations' in batch_keys else None,
            self.rewards[max_length][batch_inds,:] if 'rewards' in batch_keys else None,
            self.dones[max_length][batch_inds,:] if 'dones' in batch_keys else None,
            self.timesteps[max_length][batch_inds,:] if 'timesteps' in batch_keys else None,
            self.masks[max_length][batch_inds,:] if 'masks' in batch_keys else None,
        )
        return TrajectoryBufferSamples(*data)
  

class MTTrajectoryBuffer(BaseBuffer):
    """Multi-task buffer that keeps one ``TrajectoryBuffer`` per added task."""

    def __init__(
        self,
        max_length: int,
        convert: bool = True,
        observation_space: spaces.Space = None,
        action_space: spaces.Space = None,
    ):
        """Create an empty multi-task buffer; tasks are added with ``add_task``."""
        super().__init__(None, observation_space, action_space)
        self.max_length = max_length
        self.convert = convert

        # one entry per task, appended in `add_task`
        self._buffers = []
        self.obs_means = []
        self.obs_stds = []

    @property
    def buffers(self):
        """List of per-task buffers, in the order the tasks were added."""
        return self._buffers

    def sample(self, batch_keys: List[str], batch_size: int, max_length: int = None) -> TrajectoryBufferSamples:
        """Sample an equal share of ``batch_size`` from every task buffer.

        Each task contributes ``int(batch_size / number_of_tasks)`` windows;
        the per-task batches are concatenated along the first axis in task
        order. Fields not named in ``batch_keys`` are ``None``.
        """
        per_task_size = int(batch_size / len(self.buffers))
        per_task_samples = [
            task_buffer.sample(batch_keys, per_task_size, max_length)
            for task_buffer in self.buffers
        ]

        merged = []
        for field in TrajectoryBufferSamples._fields:
            if field in batch_keys:
                parts = [getattr(samples, field) for samples in per_task_samples]
                merged.append(np.concatenate(parts, axis=0))
            else:
                merged.append(None)
        return TrajectoryBufferSamples(*merged)

    def _get_samples(self, batch_inds: np.ndarray) -> TrajectoryBufferSamples:
        """Not supported; sampling is delegated to the per-task buffers."""
        raise NotImplementedError

    def add_task(self, trajectories: List[Dict[str, np.ndarray]]) -> None:
        """Wrap ``trajectories`` in a new ``TrajectoryBuffer`` and register it.

        Also records the new buffer's observation mean/std and updates
        ``skill_dim`` to the new buffer's skill dimension.
        """
        task_buffer = TrajectoryBuffer(
            trajectories,
            self.max_length,
            self.convert,
            self.observation_space,
            self.action_space,
        )
        self.buffers.append(task_buffer)

        newest = self.buffers[-1]
        self.skill_dim = task_buffer.skill_dim
        self.obs_means.append(newest.obs_mean)
        self.obs_stds.append(newest.obs_std)


class PredictorBufferSamples(NamedTuple):
    """Batch of windows with a per-window termination label.

    Shapes below are the ones produced by ``PredictorBuffer`` in this file;
    each batch lists the not-done windows first, followed by the done windows.
    """
    # (batch,) + obs_shape float32, first observation of the source trajectory
    start_observations: np.ndarray
    # (batch, max_length) + obs_shape float32, observations of the window
    observations: np.ndarray
    # (batch, max_length, act_dim) float32
    actions: np.ndarray
    # (batch, skill_dim) float32, first skill embedding of the source trajectory
    skills: np.ndarray
    # (batch, 1) float32, 1 when the window ends at the trajectory's last index
    dones: np.ndarray


class PredictorBuffer(BaseBuffer):
    """Buffer of skill-segment windows labelled with a termination flag.

    Trajectories are always split into skill segments with ``convert_traj``.
    Windows are kept in two pools, not-done (``x_*``) and done (``y_*``), so
    that ``sample`` can return a balanced batch.

    ``obs_shape``, ``obs_dim`` and ``act_dim`` are read from ``self`` and are
    presumably provided by ``BaseBuffer`` (not visible in this file).
    """

    def __init__(
        self,
        trajectories: List[Dict[str, np.ndarray]],
        max_length: int,
        observation_space: spaces.Space = None,
        action_space: spaces.Space = None,
    ):
        """Split ``trajectories`` into skill segments and build both pools."""
        super(PredictorBuffer, self).__init__(None, observation_space, action_space)
        self.trajectories = convert_traj(trajectories)
        self.max_length = max_length
        self.setup()

    def setup(self) -> None:
        """Compute dataset statistics and materialise the two window pools."""
        # get skill dimension
        self.skill_dim = self.trajectories[0]['skill_embds'][0].shape[-1]
        # trajectory lengths (counted in observations)
        observations, actions, traj_lengths = [], [], []
        for path in self.trajectories:
            observations.append(path['observations'])
            actions.append(path['actions'])
            traj_lengths.append(len(path['observations']))
        self.traj_lengths = np.array(traj_lengths)
        observations, actions = np.concatenate(observations, axis=0), np.concatenate(actions, axis=0)
        # Statistics are exposed as attributes; the stored windows themselves
        # are kept unnormalised.
        self.obs_mean, self.obs_std = np.mean(observations, axis=0), np.std(observations, axis=0) + 1e-6
        self.act_mean, self.act_std = np.mean(actions, axis=0), np.std(actions, axis=0) + 1e-6

        # trajectory indexes: (trajectory, start, end, done)
        x_inds, y_inds = [], []
        for i, traj_length in enumerate(self.traj_lengths):
            for start in range(traj_length - self.max_length):
                end = start + self.max_length
                # a window is "done" when it ends at the segment's last index
                if end == traj_length - 1:
                    y_inds.append((i, start, end, 1))
                else:
                    x_inds.append((i, start, end, 0))
        self.x_inds = np.array(x_inds)
        self.y_inds = np.array(y_inds)

        self.x_start_observations, self.x_observations, self.x_actions, self.x_skills, self.x_dones = self._set_inds(x_inds, self.max_length)
        self.y_start_observations, self.y_observations, self.y_actions, self.y_skills, self.y_dones = self._set_inds(y_inds, self.max_length)

    def _set_inds(
        self,
        inds: List[Tuple[int, int, int, int]],
        max_length: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Materialise the windows listed in ``inds``.

        Args:
            inds: ``(trajectory, start, end, done)`` tuples.
            max_length: window length used to size the arrays.

        Returns:
            ``(start_observations, observations, actions, skills, dones)``.
        """
        num_windows = len(inds)
        start_observations = np.zeros((num_windows,) + self.obs_shape, dtype=np.float32)
        observations = np.zeros((num_windows, max_length) + self.obs_shape, dtype=np.float32)
        actions = np.zeros((num_windows, max_length, self.act_dim), dtype=np.float32)
        skills = np.zeros((num_windows, self.skill_dim), dtype=np.float32)
        dones = np.zeros((num_windows, 1), dtype=np.float32)

        for i, (ind, si, en, done) in enumerate(inds):
            traj = self.trajectories[ind]
            # start observation and skill come from the segment's first step,
            # independent of where the window itself starts
            start_observations[i] = traj['observations'][0]
            observations[i] = traj['observations'][si:en].reshape(1, -1, self.obs_dim)
            actions[i] = traj['actions'][si:en].reshape(1, -1, self.act_dim)
            skills[i] = traj['skill_embds'][0]
            dones[i] = done
        return start_observations, observations, actions, skills, dones

    def sample(self, batch_size: int) -> PredictorBufferSamples:
        """Draw a balanced batch: half not-done windows, half done windows.

        Each half holds ``int(batch_size / 2)`` windows drawn with replacement,
        so an odd ``batch_size`` yields ``batch_size - 1`` samples.

        Raises:
            ValueError: if either pool is empty while samples are requested.
        """
        half = int(batch_size / 2)
        if half > 0 and (len(self.x_inds) == 0 or len(self.y_inds) == 0):
            raise ValueError(
                "cannot sample a balanced batch: "
                f"{len(self.x_inds)} not-done and {len(self.y_inds)} done windows "
                f"of length {self.max_length} are available"
            )
        x_batch_inds = np.random.choice(
            np.arange(len(self.x_inds)),
            size=half,
            replace=True,
        )
        y_batch_inds = np.random.choice(
            np.arange(len(self.y_inds)),
            size=half,
            replace=True,
        )
        return self._get_samples(x_batch_inds, y_batch_inds)

    def _get_samples(self, x_batch_inds: np.ndarray, y_batch_inds: np.ndarray) -> PredictorBufferSamples:
        """Concatenate the selected not-done windows followed by the done ones."""
        data = (
            np.concatenate((self.x_start_observations[x_batch_inds], self.y_start_observations[y_batch_inds]), axis=0),
            np.concatenate((self.x_observations[x_batch_inds], self.y_observations[y_batch_inds]), axis=0),
            np.concatenate((self.x_actions[x_batch_inds], self.y_actions[y_batch_inds]), axis=0),
            np.concatenate((self.x_skills[x_batch_inds], self.y_skills[y_batch_inds]), axis=0),
            np.concatenate((self.x_dones[x_batch_inds], self.y_dones[y_batch_inds]), axis=0),
        )
        return PredictorBufferSamples(*data)
